import numpy as np
from typing import List, Union, Dict, Tuple, Optional

from ModelsUtils.Metrics import _get_top_k_names, _precision_at_k_actual_calc, precision_at_k_selected_group
from Utils import logger


def __agent_precision_at_k(top_models, selected_models, k):
    """
    Fraction of the agent's first ``k`` picks that appear among the first ``k`` entries of ``top_models``.

    :param top_models: ranking of model names; its first ``k`` entries are treated as the true top-k.
    :param selected_models: model names chosen by the agent, in selection order.
    :param k: cut-off applied to both sequences (also the denominator).
    :return: number of hits in ``selected_models[:k]`` divided by ``k``.
    """
    true_best = top_models[:k]
    hits = np.isin(selected_models[:k], true_best)
    return hits.sum() / k


def __precision_at_k_from_all(top_models, selected_models, k):
    """
    Count how many of ALL selected models fall in the first ``k`` entries of ``top_models``, divided by ``k``.

    Unlike ``__agent_precision_at_k`` the selection is not truncated to ``k`` picks.

    :param top_models: ranking of model names; its first ``k`` entries are treated as the true top-k.
    :param selected_models: every model name chosen by the agent.
    :param k: cut-off on ``top_models`` (also the denominator).
    :return: hits among ``selected_models`` divided by ``k``.
    """
    membership = np.isin(selected_models, top_models[:k])
    return membership.sum() / k


def _calc_accuracy_diff_from_top(all_models_names: Union[List, np.ndarray], y_true: np.ndarray,
                                 selected_models: Union[List, np.ndarray], k: int = 10):
    """
    Compare the true accuracies of the selected models with the best accuracies that were achievable.

    :param all_models_names: names of all candidate models, aligned index-by-index with ``y_true``.
    :param y_true: true accuracy of every model in ``all_models_names``.
    :param selected_models: names of the models chosen by the agent.
    :param k: unused; kept only for backward compatibility with existing callers.
    :return: dict with
        ``top_true`` - the ``len(selected_models)`` highest true accuracies, descending;
        ``top_selected_acc`` - true accuracies of the selected models, descending;
        ``top_diff`` - best achievable accuracy minus best selected accuracy;
        ``second_top_diff`` / ``third_top_diff`` - the same gap at rank 2 / 3;
        ``mean_diff`` - mean of the rank-wise gaps.
        The rank-2/3 and mean gaps are -1 when fewer than 3 ranks can be paired.
        When ``top_true[0] > 1`` (accuracies stored multiplied by 10) the gaps are divided by 10;
        the -1 sentinels are left untouched.
    :raises IndexError: if none of ``selected_models`` appears in ``all_models_names``.
    """
    y_true = np.asarray(y_true)
    # Best achievable accuracies for a selection of this size, best first.
    top_true = sorted(y_true, reverse=True)[:len(selected_models)]
    # True accuracies of the models the agent actually picked, best first.
    top_selected_acc = sorted(y_true[np.isin(all_models_names, selected_models)], reverse=True)
    top_acc_diff = top_true[0] - top_selected_acc[0]

    # Rank-wise comparisons only make sense for ranks present in both lists; the lists can differ
    # in length when names are duplicated in all_models_names or missing from it.
    paired = min(len(top_true), len(top_selected_acc))
    has_rank_diffs = paired >= 3
    if has_rank_diffs:
        second_top_diff = top_true[1] - top_selected_acc[1]
        third_top_diff = top_true[2] - top_selected_acc[2]
        mean_diff = np.mean(np.array(top_true[:paired]) - np.array(top_selected_acc[:paired]))
    else:
        second_top_diff = -1
        third_top_diff = -1
        mean_diff = -1

    if top_true[0] > 1:     # This happens when mult 10 was used in results; report diffs in range [0, 1]
        top_acc_diff = top_acc_diff / 10
        if has_rank_diffs:  # keep the -1 "not available" sentinels intact
            second_top_diff = second_top_diff / 10
            third_top_diff = third_top_diff / 10
            mean_diff = mean_diff / 10

    return dict(top_true=top_true, top_selected_acc=top_selected_acc, top_diff=top_acc_diff,
                second_top_diff=second_top_diff, third_top_diff=third_top_diff, mean_diff=mean_diff)


def optimistic_random_precision_at(all_models_names, y_true, selected_models, selected_preds, k):
    """
    Optimistic precision@k of the agent's selection.

    :param all_models_names: names of all candidate models, aligned with ``y_true``.
    :param y_true: true score of every model in ``all_models_names``.
    :param selected_models: names chosen by the agent (list or numpy array).
    :param selected_preds: predicted scores aligned with ``selected_models``, or None. When None the
        computation is delegated to ``precision_at_k_selected_group``.
    :param k: precision cut-off.
    :return: whatever ``precision_at_k_selected_group`` / ``_precision_at_k_actual_calc`` return.
    :raises ValueError: if ``selected_preds`` is given and no selected model appears in ``all_models_names``.
    """
    if selected_preds is None:
        return precision_at_k_selected_group(names=all_models_names, y_true=y_true, selected_models=selected_models,
                                             k=k, random_optimistic=True)

    # Position of the FIRST occurrence of every selected name (same as list.index). Built once, so the
    # scan below is linear and also works when selected_models is a numpy array (which has no .index()).
    first_position = {}
    for position, name in enumerate(selected_models):
        first_position.setdefault(name, position)

    # Selected models in the order of all_models_names, paired with their predicted scores.
    pairs = [(name, selected_preds[first_position[name]])
             for name in all_models_names if name in first_position]
    if not pairs:
        raise ValueError('optimistic_random_precision_at: none of selected_models appears in all_models_names')
    models_for_optimistic, preds_for_optimistic = zip(*pairs)

    top_selected, selected_eq_loc = _get_top_k_names(models_for_optimistic, preds_for_optimistic, k=k, optimistic=True)
    top_true, _ = _get_top_k_names(all_models_names, y_true, k=k, optimistic=True)
    return _precision_at_k_actual_calc(sorted_true=top_true, sorted_preds=top_selected, k=k, random_optimistic=True,
                                       equality_loc=selected_eq_loc)


def eval_agent_selection(all_models_names: Union[List, np.ndarray], y_true: np.ndarray, selected_preds: Optional[List[float]],
                         selected_models: Union[List, np.ndarray], percentiles: Tuple[float, ...],
                         resources: int, k_lst: Tuple[int, ...], specific_model_diff: Optional[str] = None) -> Dict:
    """
    Evaluate an agent's model selection against the true model ranking.

    Assumes ``selected_models`` and ``selected_preds`` are ordered from big to small according to predicted accuracy.

    :param all_models_names: names of all candidate models, aligned index-by-index with ``y_true``.
    :param y_true: true accuracy of every model in ``all_models_names``.
    :param selected_preds: predicted accuracies aligned with ``selected_models``, or None.
    :param selected_models: names of the models chosen by the agent.
    :param percentiles: fractions of ``len(all_models_names)`` used as precision cut-offs (at least 1 model).
    :param resources: copied verbatim into every result row (presumably the agent's budget - confirm with caller).
    :param k_lst: absolute precision cut-offs.
    :param specific_model_diff: optional model name whose top-accuracy gap is stored as ``specific_diff``.
    :return: dict mapping each percent and each k to a result row. Percent rows hold
        ``p@k_percent`` / ``p@k_percent-all`` / ``p@k_percent-opt`` with ``k == -1``; k rows hold
        ``p@k`` / ``p@k-all`` / ``p@k-opt`` with ``percent == -1``. Every row also carries the fields
        returned by ``_calc_accuracy_diff_from_top``.

    NOTE: percent rows and k rows share one dict, so a percent equal to a k (e.g. ``1.0`` and ``1``)
    is overwritten by the k row.
    """
    if len(selected_models) == 0:
        logger().error('eval_agent_selection', ValueError, 'NO Fucking Model was selected WTF ?????')
    n_models = len(all_models_names)
    n_selected = len(selected_models)
    top_models, _ = _get_top_k_names(all_models_names, y_true, k=n_models, optimistic=False)

    diff_rec = _calc_accuracy_diff_from_top(all_models_names=all_models_names, y_true=y_true,
                                            selected_models=selected_models)
    if specific_model_diff is not None:
        specific_res = _calc_accuracy_diff_from_top(all_models_names=all_models_names, y_true=y_true,
                                                    selected_models=[specific_model_diff])
        diff_rec['specific_diff'] = specific_res['top_diff']

    def build_row(eval_size, percent, k, metric_key):
        # Both cut-off kinds report the same three precision variants; only the key names differ.
        # Dict literal entries are evaluated left to right, so the metric call order is preserved.
        return {'dataset_size': n_models, 'selected_size': n_selected, 'eval_size': eval_size, 'percent': percent,
                metric_key: __agent_precision_at_k(top_models, selected_models, eval_size),
                f'{metric_key}-all': __precision_at_k_from_all(top_models, selected_models, eval_size),
                f'{metric_key}-opt': optimistic_random_precision_at(all_models_names=all_models_names, y_true=y_true,
                                                                    selected_preds=selected_preds,
                                                                    selected_models=selected_models, k=eval_size),
                'resources': resources, 'k': k, **diff_rec}

    selection_res = dict()
    for percent in percentiles:
        size = int(n_models * percent) or 1     # always evaluate at least one model
        selection_res[percent] = build_row(size, percent, -1, 'p@k_percent')

    for k in k_lst:
        selection_res[k] = build_row(k, -1, k, 'p@k')

    return selection_res
